{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 画像とテキストをVLMで処理する\n",
    "\n",
    "このノートブックでは、`HuggingFaceTB/SmolVLM-Instruct` 4bit量子化モデルを使用して、以下のようなさまざまなマルチモーダルタスクを実行する方法を示します：\n",
    "- 視覚質問応答（VQA）：画像の内容に基づいて質問に答える。\n",
    "- テキスト認識（OCR）：画像内のテキストを抽出して解釈する。\n",
    "- ビデオ理解：連続フレーム解析を通じてビデオを説明する。\n",
    "\n",
    "プロンプトを効果的に構築することで、シーン理解、ドキュメント分析、動的視覚推論など、多くのアプリケーションにモデルを活用できます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Google Colabでの要件をインストール\n",
    "# !pip install transformers datasets trl huggingface_hub bitsandbytes\n",
    "\n",
    "# Hugging Faceに認証\n",
    "from huggingface_hub import notebook_login\n",
    "notebook_login()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`low_cpu_mem_usage` was None, now default to True since model is quantized.\n",
      "You shouldn't move a model that is dispatched using accelerate hooks.\n",
      "Some kwargs in processor config are unused and will not have any effect: image_seq_len. \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'longest_edge': 1536}\n"
     ]
    }
   ],
   "source": [
    "import torch, PIL\n",
    "from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig\n",
    "from transformers.image_utils import load_image\n",
    "\n",
    "device = (\n",
    "    \"cuda\"\n",
    "    if torch.cuda.is_available()\n",
    "    else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
    ")\n",
    "\n",
    "quantization_config = BitsAndBytesConfig(load_in_4bit=True)\n",
    "model_name = \"HuggingFaceTB/SmolVLM-Instruct\"\n",
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "    model_name,\n",
    "    quantization_config=quantization_config,\n",
    ").to(device)\n",
    "processor = AutoProcessor.from_pretrained(\"HuggingFaceTB/SmolVLM-Instruct\")\n",
    "\n",
    "print(processor.image_processor.size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 画像の処理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "まず、画像のキャプションを生成し、画像に関する質問に答えることから始めましょう。また、複数の画像を処理する方法も探ります。\n",
    "### 1. 単一画像のキャプション生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.pixabay.com/photo/2024/11/20/09/14/christmas-9210799_1280.jpg\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.pixabay.com/photo/2024/11/23/08/18/christmas-9218404_1280.jpg\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "image_url1 = \"https://cdn.pixabay.com/photo/2024/11/20/09/14/christmas-9210799_1280.jpg\"\n",
    "display(Image(url=image_url1))\n",
    "\n",
    "image_url2 = \"https://cdn.pixabay.com/photo/2024/11/23/08/18/christmas-9218404_1280.jpg\"\n",
    "display(Image(url=image_url2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/duydl/Miniconda3/envs/py310/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:451: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['User:<image>Can you describe the image?\\nAssistant: The image is a scene of a person walking in a forest. The person is wearing a coat and a cap. The person is holding the hand of another person. The person is walking on a path. The path is covered with dry leaves. The background of the image is a forest with trees.']\n"
     ]
    }
   ],
   "source": [
    "# 画像を1枚読み込む\n",
    "image1 = load_image(image_url1)\n",
    "\n",
    "# 入力メッセージを作成\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"image\"},\n",
    "            {\"type\": \"text\", \"text\": \"Can you describe the image?\"}\n",
    "        ]\n",
    "    },\n",
    "]\n",
    "\n",
    "# 入力を準備\n",
    "prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
    "inputs = processor(text=prompt, images=[image1], return_tensors=\"pt\")\n",
    "inputs = inputs.to(device)\n",
    "\n",
    "# 出力を生成\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
    "generated_texts = processor.batch_decode(\n",
    "    generated_ids,\n",
    "    skip_special_tokens=True,\n",
    ")\n",
    "\n",
    "print(generated_texts)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 複数画像の比較\n",
    "モデルは複数の画像を処理して比較することもできます。2つの画像の共通のテーマを判断してみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['User:<image>What event do they both represent?\\nAssistant: Christmas.']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# 画像を読み込む\n",
    "image2 = load_image(image_url2)\n",
    "\n",
    "# 入力メッセージを作成\n",
    "messages = [\n",
    "    # {\n",
    "    #     \"role\": \"user\",\n",
    "    #     \"content\": [\n",
    "    #         {\"type\": \"image\"},\n",
    "    #         {\"type\": \"image\"},\n",
    "    #         {\"type\": \"text\", \"text\": \"Can you describe the two images?\"}\n",
    "    #     ]\n",
    "    # },\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"image\"},\n",
    "            {\"type\": \"image\"},\n",
    "            {\"type\": \"text\", \"text\": \"What event do they both represent?\"}\n",
    "        ]\n",
    "    },\n",
    "]\n",
    "\n",
    "# 入力を準備\n",
    "prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
    "inputs = processor(text=prompt, images=[image1, image2], return_tensors=\"pt\")\n",
    "inputs = inputs.to(device)\n",
    "\n",
    "# 出力を生成\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
    "generated_texts = processor.batch_decode(\n",
    "    generated_ids,\n",
    "    skip_special_tokens=True,\n",
    ")\n",
    "\n",
    "print(generated_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 🔠 テキスト認識（OCR）\n",
    "VLMは画像内のテキストを認識して解釈することもでき、ドキュメント分析などのタスクに適しています。\n",
    "テキストが密集している画像で実験してみてください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.pixabay.com/photo/2020/11/30/19/23/christmas-5792015_960_720.png\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['User:<image>What is written?\\nAssistant: MERRY CHRISTMAS AND A HAPPY NEW YEAR']\n"
     ]
    }
   ],
   "source": [
    "document_image_url = \"https://cdn.pixabay.com/photo/2020/11/30/19/23/christmas-5792015_960_720.png\"\n",
    "display(Image(url=document_image_url))\n",
    "\n",
    "# ドキュメント画像を読み込む\n",
    "document_image = load_image(document_image_url)\n",
    "\n",
    "# 分析のための入力メッセージを作成\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"image\"},\n",
    "            {\"type\": \"text\", \"text\": \"What is written?\"}\n",
    "        ]\n",
    "    }\n",
    "]\n",
    "\n",
    "# 入力を準備\n",
    "prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
    "inputs = processor(text=prompt, images=[document_image], return_tensors=\"pt\")\n",
    "inputs = inputs.to(device)\n",
    "\n",
    "# 出力を生成\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
    "generated_texts = processor.batch_decode(\n",
    "    generated_ids,\n",
    "    skip_special_tokens=True,\n",
    ")\n",
    "\n",
    "print(generated_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ビデオの処理\n",
    "\n",
    "視覚言語モデル（VLM）は、キーフレームを抽出し、時間順に推論することで間接的にビデオを処理できます。VLMは専用のビデオモデルの時間的モデリング機能を欠いていますが、それでも以下のことが可能です：\n",
    "- サンプリングされたフレームを順次分析することで、アクションやイベントを説明する。\n",
    "- 代表的なキーフレームに基づいてビデオに関する質問に答える。\n",
    "- 複数のフレームのテキスト記述を組み合わせてビデオの内容を要約する。\n",
    "\n",
    "例を使って実験してみましょう：\n",
    "\n",
    "<video width=\"640\" height=\"360\" controls>\n",
    "  <source src=\"https://cdn.pixabay.com/video/2023/10/28/186794-879050032_large.mp4\" type=\"video/mp4\">\n",
    "  Your browser does not support the video tag.\n",
    "</video>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install opencv-python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video\n",
    "import cv2\n",
    "import numpy as np\n",
    "\n",
    "def extract_frames(video_path, max_frames=50, target_size=None):\n",
    "    cap = cv2.VideoCapture(video_path)\n",
    "    if not cap.isOpened():\n",
    "        raise ValueError(f\"Could not open video: {video_path}\")\n",
    "    \n",
    "    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "    frame_indices = np.linspace(0, total_frames - 1, max_frames, dtype=int)\n",
    "\n",
    "    frames = []\n",
    "    for idx in frame_indices:\n",
    "        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)\n",
    "        ret, frame = cap.read()\n",
    "        if ret:\n",
    "            frame = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n",
    "            if target_size:\n",
    "                frames.append(resize_and_crop(frame, target_size))\n",
    "            else:\n",
    "                frames.append(frame)\n",
    "    cap.release()\n",
    "    return frames\n",
    "\n",
    "def resize_and_crop(image, target_size):\n",
    "    width, height = image.size\n",
    "    scale = target_size / min(width, height)\n",
    "    image = image.resize((int(width * scale), int(height * scale)), PIL.Image.Resampling.LANCZOS)\n",
    "    left = (image.width - target_size) // 2\n",
    "    top = (image.height - target_size) // 2\n",
    "    return image.crop((left, top, left + target_size, top + target_size))\n",
    "\n",
    "# ビデオリンク\n",
    "video_link = \"https://cdn.pixabay.com/video/2023/10/28/186794-879050032_large.mp4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Response: User: Following are the frames of a video in temporal order.<image>Describe what the woman is doing.\n",
      "Assistant: The woman is hanging an ornament on a Christmas tree.\n"
     ]
    }
   ],
   "source": [
    "question = \"Describe what the woman is doing.\"\n",
    "\n",
    "def generate_response(model, processor, frames, question):\n",
    "\n",
    "    image_tokens = [{\"type\": \"image\"} for _ in frames]\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [{\"type\": \"text\", \"text\": \"Following are the frames of a video in temporal order.\"}, *image_tokens, {\"type\": \"text\", \"text\": question}]\n",
    "        }\n",
    "    ]\n",
    "    inputs = processor(\n",
    "        text=processor.apply_chat_template(messages, add_generation_prompt=True),\n",
    "        images=frames,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(model.device)\n",
    "\n",
    "    outputs = model.generate(\n",
    "        **inputs, max_new_tokens=100, num_beams=5, temperature=0.7, do_sample=True, use_cache=True\n",
    "    )\n",
    "    return processor.decode(outputs[0], skip_special_tokens=True)\n",
    "\n",
    "\n",
    "# ビデオからフレームを抽出\n",
    "frames = extract_frames(video_link, max_frames=15, target_size=384)\n",
    "\n",
    "processor.image_processor.size = (384, 384)\n",
    "processor.image_processor.do_resize = False\n",
    "# 応答を生成\n",
    "response = generate_response(model, processor, frames, question)\n",
    "\n",
    "# 結果を表示\n",
    "# print(\"Question:\", question)\n",
    "print(\"Response:\", response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 💐 完了です！\n",
    "\n",
    "このノートブックでは、マルチモーダルタスクのプロンプトをフォーマットする方法など、視覚言語モデル（VLM）を使用する方法を示しました。ここで説明した手順に従うことで、VLMとそのアプリケーションを実験できます。\n",
    "\n",
    "### 次のステップ\n",
    "- VLMのさらなる使用例を試してみてください。\n",
    "- 同僚と協力して、彼らのプルリクエスト（PR）をレビューしてください。\n",
    "- 新しい使用例、例、概念を導入するために、問題を開いたりPRを提出したりして、このコース資料の改善に貢献してください。\n",
    "\n",
    "楽しい探求を！🌟"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
